{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0_xya1nyDHfY",
   "metadata": {
    "id": "0_xya1nyDHfY"
   },
   "source": [
    "<table style=\"width:100%\">\n",
    "<tr>\n",
    "<td style=\"vertical-align:middle; text-align:left;\">\n",
    "<font size=\"2\">\n",
    "Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
    "<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
    "</font>\n",
    "</td>\n",
    "<td style=\"vertical-align:middle; text-align:left;\">\n",
    "<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
    "</td>\n",
    "</tr>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "l62zIRRSBy_R",
   "metadata": {
    "id": "l62zIRRSBy_R"
   },
   "source": [
    "# Converting a From-Scratch GPT Architecture to Llama 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aFmxTQbwCUMl",
   "metadata": {
    "id": "aFmxTQbwCUMl"
   },
   "source": [
    "- In this notebook, we convert the original GPT architecture into a Llama 2 model step by step (note that GPT and GPT-2 share the same architecture)\n",
    "- Why not Llama 1 or Llama 3?\n",
    "   - The Llama 1 architecture is similar to Llama 2, except that Llama 2 has a larger context window (which is nice); the Llama 1 weights are not readily available and have more usage restrictions, so it makes more sense to focus on Llama 2\n",
    "   - Regarding Llama 3, I will share a separate notebook to convert Llama 2 to Llama 3 (there are only a few small additional changes)\n",
    "- The explanations in this notebook are purposefully kept minimal to avoid bloating it and to keep the focus on the main code\n",
    "- For more information, please see the Llama 2 paper: [Llama 2: Open Foundation and Fine-Tuned Chat Models (2023)](https://arxiv.org/abs/2307.09288)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ohhMKUWvGm9z",
   "metadata": {
    "id": "ohhMKUWvGm9z"
   },
   "source": [
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/gpt2-to-llama2-llama3.webp?1\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "JBpQwU89ETA1",
   "metadata": {
    "id": "JBpQwU89ETA1"
   },
   "source": [
    "- Packages that are being used in this notebook:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "34a9a440-84c2-42cc-808b-38677cb6af8a",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "34a9a440-84c2-42cc-808b-38677cb6af8a",
    "outputId": "8118963b-3c72-43af-874b-439ffebdc94c"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "huggingface_hub version: 0.24.7\n",
      "sentencepiece version: 0.2.0\n",
      "torch version: 2.4.1+cu121\n"
     ]
    }
   ],
   "source": [
    "from importlib.metadata import version\n",
    "\n",
    "pkgs = [\n",
    "    \"huggingface_hub\",  # to download pretrained weights\n",
    "    \"sentencepiece\",    # to implement the tokenizer\n",
    "    \"torch\",            # to implement the model\n",
    "]\n",
    "for p in pkgs:\n",
    "    print(f\"{p} version: {version(p)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "UJJneXpTEg4W",
   "metadata": {
    "id": "UJJneXpTEg4W"
   },
   "source": [
    "&nbsp;\n",
    "# 1. Convert the GPT model implementation step by step"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "v1zpfX2GHBKa",
   "metadata": {
    "id": "v1zpfX2GHBKa"
   },
   "source": [
    "- In this section, we go through the GPT model code from [chapter 4](../../ch04/01_main-chapter-code/ch04.ipynb) and modify it step by step to implement the Llama 2 architecture\n",
    "- Later, we load the original Llama 2 weights shared by Meta AI"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "979c7b6d-1370-4da1-8bfb-a2b27537bf2f",
   "metadata": {
    "id": "979c7b6d-1370-4da1-8bfb-a2b27537bf2f"
   },
   "source": [
    "&nbsp;\n",
    "## 1.1 Replace LayerNorm with RMSNorm layer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f8b27fc8-23a1-4e0e-a1ea-792e0428e5e6",
   "metadata": {
    "id": "f8b27fc8-23a1-4e0e-a1ea-792e0428e5e6"
   },
   "source": [
    "- First, we replace LayerNorm with Root Mean Square Layer Normalization (RMSNorm)\n",
    "- LayerNorm normalizes inputs using mean and variance, while RMSNorm uses only the root mean square, which improves computational efficiency\n",
    "- The RMSNorm operation is as follows, where $x$ is the input with $n$ elements along the embedding dimension, $\\gamma$ is a trainable parameter (vector), and $\\epsilon$ is a small constant to avoid zero-division errors:\n",
    "\n",
    "$$y_i = \\frac{x_i}{\\text{RMS}(x)} \\gamma_i, \\quad \\text{where} \\quad \\text{RMS}(x) = \\sqrt{\\epsilon + \\frac{1}{n} \\sum_{j=1}^{n} x_j^2}$$\n",
    "\n",
    "- For more details, please see the paper [Root Mean Square Layer Normalization (2019)](https://arxiv.org/abs/1910.07467)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d7094381-9499-4e9e-93f9-b79470da3771",
   "metadata": {
    "id": "d7094381-9499-4e9e-93f9-b79470da3771"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "#####################################\n",
    "# Chapter 4\n",
    "#####################################\n",
    "\n",
    "# class LayerNorm(nn.Module):\n",
    "#     def __init__(self, emb_dim):\n",
    "#         super().__init__()\n",
    "#         self.eps = 1e-5\n",
    "#         self.scale = nn.Parameter(torch.ones(emb_dim))\n",
    "#         self.shift = nn.Parameter(torch.zeros(emb_dim))\n",
    "\n",
    "#     def forward(self, x):\n",
    "#         mean = x.mean(dim=-1, keepdim=True)\n",
    "#         var = x.var(dim=-1, keepdim=True, unbiased=False)\n",
    "#         norm_x = (x - mean) / torch.sqrt(var + self.eps)\n",
    "#         return self.scale * norm_x + self.shift\n",
    "\n",
    "\n",
    "class RMSNorm(nn.Module):\n",
    "    def __init__(self, emb_dim, eps=1e-5):\n",
    "        super().__init__()\n",
    "        self.eps = eps\n",
    "        self.emb_dim = emb_dim\n",
    "        self.weight = nn.Parameter(torch.ones(emb_dim, dtype=torch.float32))\n",
    "\n",
    "    def forward(self, x):\n",
    "        means = x.pow(2).mean(dim=-1, keepdim=True)\n",
    "        x_normed = x * torch.rsqrt(means + self.eps)\n",
    "        return (x_normed * self.weight).to(dtype=x.dtype)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "mtWC8DOmIu0F",
   "metadata": {
    "id": "mtWC8DOmIu0F"
   },
   "source": [
    "- The following code cell checks that this implementation produces the same output as PyTorch's built-in `torch.nn.RMSNorm` (available in PyTorch 2.4 and newer):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e41ade7a-bf06-48b1-8b7e-0e4037d5753f",
   "metadata": {
    "id": "e41ade7a-bf06-48b1-8b7e-0e4037d5753f"
   },
   "outputs": [],
   "source": [
    "torch.manual_seed(123)\n",
    "\n",
    "example_batch = torch.randn(2, 3, 4)\n",
    "\n",
    "rms_norm = RMSNorm(emb_dim=example_batch.shape[-1])\n",
    "rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-5)\n",
    "\n",
    "assert torch.allclose(rms_norm(example_batch), rmsnorm_pytorch(example_batch))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5eb81f83-c38c-46a4-b763-aa630a32e357",
   "metadata": {
    "id": "5eb81f83-c38c-46a4-b763-aa630a32e357"
   },
   "source": [
    "&nbsp;\n",
    "## 1.2 Replace GELU with SiLU activation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b8aa702-f118-4ff6-9135-90725ec8756c",
   "metadata": {
    "id": "0b8aa702-f118-4ff6-9135-90725ec8756c"
   },
   "source": [
    "- Llama uses the SiLU activation function (instead of GELU), which is also known as the Swish function:\n",
    "\n",
    "$$\n",
    "\\text{silu}(x) = x \\cdot \\sigma(x), \\quad \\text{where} \\quad \\sigma(x) \\text{ is the logistic sigmoid.}\n",
    "$$\n",
    "\n",
    "- For more information, see the SiLU paper: [Sigmoid-Weighted Linear Units for Neural Network Function Approximation in Reinforcement Learning (2017)](https://arxiv.org/abs/1702.03118)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a74f3757-c634-4a3a-a8f3-6334cde454fe",
   "metadata": {
    "id": "a74f3757-c634-4a3a-a8f3-6334cde454fe"
   },
   "outputs": [],
   "source": [
    "#####################################\n",
    "# Chapter 4\n",
    "#####################################\n",
    "\n",
    "# class GELU(nn.Module):\n",
    "#     def __init__(self):\n",
    "#         super().__init__()\n",
    "\n",
    "#     def forward(self, x):\n",
    "#         return 0.5 * x * (1 + torch.tanh(\n",
    "#             torch.sqrt(torch.tensor(2.0 / torch.pi)) *\n",
    "#             (x + 0.044715 * torch.pow(x, 3))\n",
    "#         ))\n",
    "\n",
    "\n",
    "class SiLU(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def forward(self, x):\n",
    "        return x * torch.sigmoid(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "72ecbe2e-b6b7-4319-972b-1a7fefa3368c",
   "metadata": {
    "id": "72ecbe2e-b6b7-4319-972b-1a7fefa3368c"
   },
   "outputs": [],
   "source": [
    "silu = SiLU()\n",
    "\n",
    "assert torch.allclose(silu(example_batch), torch.nn.functional.silu(example_batch))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f9b5167-1da9-46c8-9964-8036b3b1deb9",
   "metadata": {
    "id": "4f9b5167-1da9-46c8-9964-8036b3b1deb9"
   },
   "source": [
    "&nbsp;\n",
    "## 1.3 Update the FeedForward module"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a381e7a-b807-472e-91c9-3e4e3fc5ad91",
   "metadata": {
    "id": "3a381e7a-b807-472e-91c9-3e4e3fc5ad91"
   },
   "source": [
    "- In fact, Llama uses a \"Gated Linear Unit\" (GLU) variant of SiLU called SwiGLU, which essentially results in a slightly differently structured `FeedForward` module\n",
    "- SwiGLU uses a gating mechanism in the feedforward layer, with the formula:\n",
    "\n",
    "$$\\text{SwiGLU}(x) = \\text{SiLU}(\\text{Linear}_1(x)) * (\\text{Linear}_2(x))$$\n",
    "\n",
    "- Here, $\\text{Linear}_1$ and $\\text{Linear}_2$ are two linear layers, and $*$ denotes element-wise multiplication\n",
    "- The third linear layer, $\\text{Linear}_3$, is applied after this gated activation\n",
    "\n",
    "- For more information, see the SwiGLU paper: [GLU Variants Improve Transformer (2020)](https://arxiv.org/abs/2002.05202)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d25fbe3d-b7c9-4772-ad67-bc0527e4e20a",
   "metadata": {
    "id": "d25fbe3d-b7c9-4772-ad67-bc0527e4e20a"
   },
   "outputs": [],
   "source": [
    "#####################################\n",
    "# Chapter 4\n",
    "#####################################\n",
    "# class FeedForward(nn.Module):\n",
    "#     def __init__(self, cfg):\n",
    "#         super().__init__()\n",
    "#         self.layers = nn.Sequential(\n",
    "#             nn.Linear(cfg[\"emb_dim\"], 4 * cfg[\"emb_dim\"]),\n",
    "#             GELU(),\n",
    "#             nn.Linear(4 * cfg[\"emb_dim\"], cfg[\"emb_dim\"]),\n",
    "#         )\n",
    "\n",
    "#     def forward(self, x):\n",
    "#         return self.layers(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "477568cb-03cd-4510-b663-a42ce3ec64a2",
   "metadata": {
    "id": "477568cb-03cd-4510-b663-a42ce3ec64a2"
   },
   "outputs": [],
   "source": [
    "class FeedForward(nn.Module):\n",
    "    def __init__(self, cfg):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
    "        self.fc2 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
    "        self.fc3 = nn.Linear(cfg[\"hidden_dim\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
    "        self.silu = SiLU()\n",
    "\n",
    "    def forward(self, x):\n",
    "        x_fc1 = self.fc1(x)\n",
    "        x_fc2 = self.fc2(x)\n",
    "        x = self.silu(x_fc1) * x_fc2\n",
    "        return self.fc3(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "qcD8LSHNhBRW",
   "metadata": {
    "id": "qcD8LSHNhBRW"
   },
   "source": [
    "- Note that we also added a `dtype=cfg[\"dtype\"]` setting above, which will allow us to load the model directly in lower precision formats later to reduce memory usage (versus instantiating it in the original 32-bit precision format and then converting it)\n",
    "- We also set `bias=False` since Llama doesn't use any bias units"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6b7bf4f-99d0-42c1-807c-5074d2cc1949",
   "metadata": {
    "id": "f6b7bf4f-99d0-42c1-807c-5074d2cc1949"
   },
   "source": [
    "&nbsp;\n",
    "## 1.4 Implement RoPE"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3487a6f-0373-49d8-b2eb-f8ee05d42884",
   "metadata": {
    "id": "d3487a6f-0373-49d8-b2eb-f8ee05d42884"
   },
   "source": [
    "- In the GPT model, the positional embeddings are implemented as follows:\n",
    "\n",
    "```python\n",
    "self.pos_emb = nn.Embedding(cfg[\"context_length\"], cfg[\"emb_dim\"])\n",
    "```\n",
    "\n",
    "- Instead of traditional absolute positional embeddings, Llama uses rotary position embeddings (RoPE), which enable it to capture both absolute and relative positional information simultaneously\n",
    "- The reference paper for RoPE is [RoFormer: Enhanced Transformer with Rotary Position Embedding (2021)](https://arxiv.org/abs/2104.09864)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a34180fb-448f-44e9-a244-0c736051687b",
   "metadata": {
    "id": "a34180fb-448f-44e9-a244-0c736051687b"
   },
   "outputs": [],
   "source": [
    "def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096):\n",
    "    assert head_dim % 2 == 0, \"Head dimension must be even\"\n",
    "\n",
    "    # Compute the inverse frequencies\n",
    "    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))\n",
    "\n",
    "    # Generate position indices\n",
    "    positions = torch.arange(context_length)\n",
    "\n",
    "    # Compute the angles\n",
    "    angles = positions[:, None] * inv_freq[None, :]  # Shape: (context_length, head_dim // 2)\n",
    "\n",
    "    # Expand angles to match the head_dim\n",
    "    angles = torch.cat([angles, angles], dim=1)  # Shape: (context_length, head_dim)\n",
    "\n",
    "    # Precompute sine and cosine\n",
    "    cos = torch.cos(angles)\n",
    "    sin = torch.sin(angles)\n",
    "\n",
    "    return cos, sin\n",
    "\n",
    "def compute_rope(x, cos, sin):\n",
    "    # x: (batch_size, num_heads, seq_len, head_dim)\n",
    "    batch_size, num_heads, seq_len, head_dim = x.shape\n",
    "    assert head_dim % 2 == 0, \"Head dimension must be even\"\n",
    "\n",
    "    # Split x into first half and second half\n",
    "    x1 = x[..., : head_dim // 2]  # First half\n",
    "    x2 = x[..., head_dim // 2 :]  # Second half\n",
    "\n",
    "    # Adjust sin and cos shapes\n",
    "    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, seq_len, head_dim)\n",
    "    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)\n",
    "\n",
    "    # Apply the rotary transformation\n",
    "    rotated = torch.cat((-x2, x1), dim=-1)\n",
    "    x_rotated = (x * cos) + (rotated * sin)\n",
    "\n",
    "    return x_rotated.to(dtype=x.dtype)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e841b8e-75aa-49db-b1a7-d5c2116dc299",
   "metadata": {
    "id": "8e841b8e-75aa-49db-b1a7-d5c2116dc299"
   },
   "source": [
    "- The following is an example of applying RoPE to the `q` and `k` tensors:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8c89f022-7167-4001-8c21-8e012878733f",
   "metadata": {
    "id": "8c89f022-7167-4001-8c21-8e012878733f"
   },
   "outputs": [],
   "source": [
    "# Settings\n",
    "batch_size = 2\n",
    "context_len = 5\n",
    "num_heads = 4\n",
    "head_dim = 16\n",
    "\n",
    "# Instantiate RoPE parameters\n",
    "cos, sin = precompute_rope_params(head_dim=head_dim, context_length=context_len)\n",
    "\n",
    "# Dummy query and key tensors\n",
    "torch.manual_seed(123)\n",
    "queries = torch.randn(batch_size, num_heads, context_len, head_dim)\n",
    "keys = torch.randn(batch_size, num_heads, context_len, head_dim)\n",
    "\n",
    "# Apply rotary position embeddings\n",
    "queries_rot = compute_rope(queries, cos, sin)\n",
    "keys_rot = compute_rope(keys, cos, sin)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f78127b0-dda2-4c5a-98dd-bae8f5fe8297",
   "metadata": {
    "id": "f78127b0-dda2-4c5a-98dd-bae8f5fe8297"
   },
   "source": [
    "&nbsp;\n",
    "## 1.5 Add RoPE to MultiHeadAttention module"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "RnmSHROLhhR3",
   "metadata": {
    "id": "RnmSHROLhhR3"
   },
   "source": [
    "- It's important to note that GPT applies the positional embeddings to the inputs, whereas Llama applies rotations to the query and key vectors in the self-attention mechanism itself\n",
    "- Here, we modify the `MultiHeadAttention` class with the appropriate RoPE code\n",
    "- In addition, we remove the `qkv_bias` option and hardcode the `bias=False` setting\n",
    "- Also, we add a dtype setting to be able to instantiate the model with a lower precision later\n",
    "- Tip: since the `TransformerBlock`s (in the next section) are repeated exactly, we could simplify the code and initialize the buffers only once instead of once for each `MultiHeadAttention` module; however, we add the precomputed RoPE parameters to the `MultiHeadAttention` class so that it can function as a standalone module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d81a441e-0b79-4a8b-8291-ea7f55d58c84",
   "metadata": {
    "id": "d81a441e-0b79-4a8b-8291-ea7f55d58c84"
   },
   "outputs": [],
   "source": [
    "#####################################\n",
    "# Chapter 3\n",
    "#####################################\n",
    "class MultiHeadAttention(nn.Module):\n",
    "    def __init__(self, d_in, d_out, context_length, num_heads, dtype=None):  # ,dropout, num_heads, qkv_bias=False):\n",
    "        super().__init__()\n",
    "        assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
    "\n",
    "        self.d_out = d_out\n",
    "        self.num_heads = num_heads\n",
    "        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim\n",
    "\n",
    "        ################################### NEW ###################################\n",
    "        # Set bias=False and dtype=dtype for all linear layers below\n",
    "        ###########################################################################\n",
    "        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n",
    "        self.W_key = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n",
    "        self.W_value = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n",
    "        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)  # Linear layer to combine head outputs\n",
    "        # self.dropout = nn.Dropout(dropout)\n",
    "        self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
    "\n",
    "        ################################### NEW ###################################\n",
    "        cos, sin = precompute_rope_params(head_dim=self.head_dim, context_length=context_length)\n",
    "        self.register_buffer(\"cos\", cos)\n",
    "        self.register_buffer(\"sin\", sin)\n",
    "        ###########################################################################\n",
    "\n",
    "\n",
    "    def forward(self, x):\n",
    "\n",
    "        b, num_tokens, d_in = x.shape\n",
    "\n",
    "        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)\n",
    "        queries = self.W_query(x)\n",
    "        values = self.W_value(x)\n",
    "\n",
    "        # We implicitly split the matrix by adding a `num_heads` dimension\n",
    "        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)\n",
    "        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)\n",
    "        values = values.view(b, num_tokens, self.num_heads, self.head_dim)\n",
    "        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)\n",
    "\n",
    "        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)\n",
    "        keys = keys.transpose(1, 2)\n",
    "        queries = queries.transpose(1, 2)\n",
    "        values = values.transpose(1, 2)\n",
    "\n",
    "        ################################### NEW ###################################\n",
    "        keys = compute_rope(keys, self.cos, self.sin)\n",
    "        queries = compute_rope(queries, self.cos, self.sin)\n",
    "        ###########################################################################\n",
    "\n",
    "        # Compute scaled dot-product attention (aka self-attention) with a causal mask\n",
    "        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head\n",
    "\n",
    "        # Original mask truncated to the number of tokens and converted to boolean\n",
    "        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n",
    "\n",
    "        # Use the mask to fill attention scores\n",
    "        attn_scores.masked_fill_(mask_bool, -torch.inf)\n",
    "\n",
    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
    "        # attn_weights = self.dropout(attn_weights)\n",
    "\n",
    "        # Shape: (b, num_tokens, num_heads, head_dim)\n",
    "        context_vec = (attn_weights @ values).transpose(1, 2)\n",
    "\n",
    "        # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
    "        context_vec = context_vec.reshape(b, num_tokens, self.d_out)\n",
    "        context_vec = self.out_proj(context_vec)  # optional projection\n",
    "\n",
    "        return context_vec"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "-lt9SfnVioB3",
   "metadata": {
    "id": "-lt9SfnVioB3"
   },
   "source": [
    "- Below is an example using the `MultiHeadAttention` module on an example input:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "03f15755-0083-483f-963b-99b599651638",
   "metadata": {
    "id": "03f15755-0083-483f-963b-99b599651638"
   },
   "outputs": [],
   "source": [
    "# Settings\n",
    "batch_size = 1\n",
    "context_len = 100\n",
    "max_context_len = 4096\n",
    "embed_dim = 128\n",
    "num_heads = 4\n",
    "\n",
    "\n",
    "example_batch = torch.randn((batch_size, context_len, embed_dim))\n",
    "\n",
    "mha = MultiHeadAttention(\n",
    "    d_in=embed_dim,\n",
    "    d_out=embed_dim,\n",
    "    context_length=max_context_len,\n",
    "    num_heads=num_heads\n",
    ")\n",
    "\n",
    "mha(example_batch)\n",
    "\n",
    "del mha  # delete to free up memory"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5a1a272-a038-4b8f-aaaa-f4b241e7f23f",
   "metadata": {
    "id": "e5a1a272-a038-4b8f-aaaa-f4b241e7f23f"
   },
   "source": [
    "&nbsp;\n",
    "## 1.6 Update the TransformerBlock module"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "255f70ac-9c2e-4328-8af7-1c298b8d4a18",
   "metadata": {
    "id": "255f70ac-9c2e-4328-8af7-1c298b8d4a18"
   },
   "source": [
    "- At this stage, most of the hard work is already done; we can now update the `TransformerBlock` to use the code we implemented above\n",
    "- This means we\n",
    "  - replace LayerNorm with RMSNorm\n",
    "  - remove dropout\n",
    "  - remove the `qkv_bias` setting\n",
    "  - add the `dtype` setting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "2e110721-bf2b-42b3-989a-1635b1658af0",
   "metadata": {
    "id": "2e110721-bf2b-42b3-989a-1635b1658af0"
   },
   "outputs": [],
   "source": [
    "class TransformerBlock(nn.Module):\n",
    "    def __init__(self, cfg):\n",
    "        super().__init__()\n",
    "        self.att = MultiHeadAttention(\n",
    "            d_in=cfg[\"emb_dim\"],\n",
    "            d_out=cfg[\"emb_dim\"],\n",
    "            context_length=cfg[\"context_length\"],\n",
    "            num_heads=cfg[\"n_heads\"],\n",
    "            dtype=cfg[\"dtype\"]  # NEW\n",
    "            # dropout=cfg[\"drop_rate\"],\n",
    "            # qkv_bias=cfg[\"qkv_bias\"]\n",
    "        )\n",
    "        self.ff = FeedForward(cfg)\n",
    "\n",
    "        ################################### NEW ###################################\n",
    "        # self.norm1 = LayerNorm(cfg[\"emb_dim\"])\n",
    "        # self.norm2 = LayerNorm(cfg[\"emb_dim\"])\n",
    "        self.norm1 = RMSNorm(cfg[\"emb_dim\"])\n",
    "        self.norm2 = RMSNorm(cfg[\"emb_dim\"])\n",
    "        ###########################################################################\n",
    "\n",
    "        # self.drop_shortcut = nn.Dropout(cfg[\"drop_rate\"])\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Shortcut connection for attention block\n",
    "        shortcut = x\n",
    "        x = self.norm1(x)\n",
    "        x = self.att(x)   # Shape [batch_size, num_tokens, emb_size]\n",
    "        # x = self.drop_shortcut(x)\n",
    "        x = x + shortcut  # Add the original input back\n",
    "\n",
    "        # Shortcut connection for feed-forward block\n",
    "        shortcut = x\n",
    "        x = self.norm2(x)\n",
    "        x = self.ff(x)\n",
    "        # x = self.drop_shortcut(x)\n",
    "        x = x + shortcut  # Add the original input back\n",
    "\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ada953bc-e2c0-4432-a32d-3f7efa3f6e0f",
   "metadata": {
    "id": "ada953bc-e2c0-4432-a32d-3f7efa3f6e0f"
   },
   "source": [
    "&nbsp;\n",
    "## 1.7 Update the model class"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba5d991a-559b-47be-96f4-31b881ab2da8",
   "metadata": {
    "id": "ba5d991a-559b-47be-96f4-31b881ab2da8"
   },
   "source": [
    "- As you may recall from [chapter 5](../01_main-chapter-code/ch05.ipynb), the `TransformerBlock` is a repeated block within the main model\n",
    "- Our Llama model is almost complete; we just have to update the model code surrounding the `TransformerBlock`\n",
    "- This means we\n",
    "  - remove absolute positional embeddings since we have RoPE embeddings now\n",
    "  - replace LayerNorm with RMSNorm\n",
    "  - remove dropout\n",
    "  - add the dtype setting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "cf8240fe-5d7f-4e7e-b1ac-e0755aab5e79",
   "metadata": {
    "id": "cf8240fe-5d7f-4e7e-b1ac-e0755aab5e79"
   },
   "outputs": [],
   "source": [
    "# class GPTModel(nn.Module):\n",
    "class Llama2Model(nn.Module):\n",
    "    def __init__(self, cfg):\n",
    "        super().__init__()\n",
    "        self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])\n",
    "        # self.pos_emb = nn.Embedding(cfg[\"context_length\"], cfg[\"emb_dim\"])\n",
    "        # self.drop_emb = nn.Dropout(cfg[\"drop_rate\"])\n",
    "\n",
    "        self.trf_blocks = nn.Sequential(\n",
    "            *[TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])])\n",
    "\n",
    "        ################################### NEW ###################################\n",
    "        # self.final_norm = LayerNorm(cfg[\"emb_dim\"])\n",
    "        self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n",
    "        ###########################################################################\n",
    "        self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
    "\n",
    "    def forward(self, in_idx):\n",
    "        # batch_size, seq_len = in_idx.shape\n",
    "        tok_embeds = self.tok_emb(in_idx)\n",
    "        # pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))\n",
    "        x = tok_embeds  # + pos_embeds  # Shape [batch_size, num_tokens, emb_size]\n",
    "        # x = self.drop_emb(x)\n",
    "        x = self.trf_blocks(x)\n",
    "        x = self.final_norm(x)\n",
    "        logits = self.out_head(x)\n",
    "        return logits"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4bc94940-aaeb-45b9-9399-3a69b8043e60",
   "metadata": {
    "id": "4bc94940-aaeb-45b9-9399-3a69b8043e60"
   },
   "source": [
    "&nbsp;\n",
    "# 2. Initialize model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bG--zY-Ljj1f",
   "metadata": {
    "id": "bG--zY-Ljj1f"
   },
   "source": [
    "- The model code is now complete, and we are ready to initialize it\n",
    "- In [chapter 5](../01_main-chapter-code/ch05.ipynb), we used the following configuration dictionary to specify the 124M-parameter GPT model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "4b7428df-3d02-4ccd-97b5-a629bdabbe8f",
   "metadata": {
    "id": "4b7428df-3d02-4ccd-97b5-a629bdabbe8f"
   },
   "outputs": [],
   "source": [
    "GPT_CONFIG_124M = {\n",
    "    \"vocab_size\": 50257,     # Vocabulary size\n",
    "    \"context_length\": 1024,  # Context length\n",
    "    \"emb_dim\": 768,          # Embedding dimension\n",
    "    \"n_heads\": 12,           # Number of attention heads\n",
    "    \"n_layers\": 12,          # Number of layers\n",
    "    \"drop_rate\": 0.1,        # Dropout rate\n",
    "    \"qkv_bias\": False        # Query-Key-Value bias\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8bVi8uiBjw2T",
   "metadata": {
    "id": "8bVi8uiBjw2T"
   },
   "source": [
    "- For reference, the 1.5B parameter GPT model config is shown below as well:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "tAOojV_mkEnd",
   "metadata": {
    "id": "tAOojV_mkEnd"
   },
   "outputs": [],
   "source": [
    "GPT_CONFIG_1558M = {\n",
    "    \"vocab_size\": 50257,     # Vocabulary size\n",
    "    \"context_length\": 1024,  # Context length\n",
    "    \"emb_dim\": 1600,         # Embedding dimension\n",
    "    \"n_heads\": 25,           # Number of attention heads\n",
    "    \"n_layers\": 48,          # Number of layers\n",
    "    \"drop_rate\": 0.1,        # Dropout rate\n",
    "    \"qkv_bias\": False        # Query-Key-Value bias\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "HoGGRAGykQTE",
   "metadata": {
    "id": "HoGGRAGykQTE"
   },
   "source": [
    "- Similarly, we can define a Llama 2 configuration dictionary for the 7B model (we ignore the other larger models for simplicity here):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "e0564727-2d35-4f0c-b0fc-cde1e9134a18",
   "metadata": {
    "id": "e0564727-2d35-4f0c-b0fc-cde1e9134a18"
   },
   "outputs": [],
   "source": [
    "LLAMA2_CONFIG_7B = {\n",
    "    \"vocab_size\": 32000,     # Vocabulary size\n",
    "    \"context_length\": 4096,  # Context length\n",
    "    \"emb_dim\": 4096,         # Embedding dimension\n",
    "    \"n_heads\": 32,           # Number of attention heads\n",
    "    \"n_layers\": 32,          # Number of layers\n",
    "    \"hidden_dim\": 11008,     # NEW: Size of the intermediate dimension in FeedForward\n",
    "    \"dtype\": torch.bfloat16  # NEW: Lower-precision dtype to reduce memory usage\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "FAP7fiBzkaBz",
   "metadata": {
    "id": "FAP7fiBzkaBz"
   },
   "source": [
    "- Using these settings, we can now initialize a Llama 2 7B model (note that this requires ~26 GB of memory in bfloat16 once gradients and buffers are counted, as calculated further below)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "7004d785-ac9a-4df5-8760-6807fc604686",
   "metadata": {
    "id": "7004d785-ac9a-4df5-8760-6807fc604686"
   },
   "outputs": [],
   "source": [
    "model = Llama2Model(LLAMA2_CONFIG_7B)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "6079f747-8f20-4c6b-8d38-7156f1101729",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "6079f747-8f20-4c6b-8d38-7156f1101729",
    "outputId": "0a0eb34b-1a21-4c11-804f-b40007bda5a3"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of parameters: 6,738,415,616\n"
     ]
    }
   ],
   "source": [
    "total_params = sum(p.numel() for p in model.parameters())\n",
    "print(f\"Total number of parameters: {total_params:,}\")"
   ]
  },
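  {
   "cell_type": "markdown",
   "id": "c3a1f0e2-param-count-sketch",
   "metadata": {},
   "source": [
    "- As a sanity check, the total can also be derived by hand from the config values; the sketch below is not part of the original notebook and assumes bias-free linear layers, one weight vector per normalization layer, and an output head that does not share weights with the token embedding (assumptions that are consistent with the tensor names in the pretrained checkpoint loaded later)\n",
    "- The helper name `estimate_llama2_params` is made up for this illustration:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3a1f0e2-param-count-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "def estimate_llama2_params(cfg):\n",
    "    emb, hidden = cfg[\"emb_dim\"], cfg[\"hidden_dim\"]\n",
    "    attention = 4 * emb * emb        # query, key, value, and output projections\n",
    "    feed_forward = 3 * emb * hidden  # three linear layers in the feed-forward module\n",
    "    norms = 2 * emb                  # two normalization weight vectors per block\n",
    "    per_block = attention + feed_forward + norms\n",
    "    embeddings = 2 * cfg[\"vocab_size\"] * emb  # token embedding plus separate output head\n",
    "    final_norm = emb\n",
    "    return cfg[\"n_layers\"] * per_block + embeddings + final_norm\n",
    "\n",
    "\n",
    "# Same values as in LLAMA2_CONFIG_7B, repeated so that this cell runs on its own\n",
    "cfg_7b = {\"vocab_size\": 32000, \"emb_dim\": 4096, \"n_layers\": 32, \"hidden_dim\": 11008}\n",
    "print(f\"Estimated number of parameters: {estimate_llama2_params(cfg_7b):,}\")"
   ]
  },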
  {
   "cell_type": "markdown",
   "id": "Bx14NtzWk2wj",
   "metadata": {
    "id": "Bx14NtzWk2wj"
   },
   "source": [
    "- As shown above, the model contains 6.7 billion parameters (commonly rounded and referred to as a 7B model)\n",
    "- Additionally, we can calculate the memory requirements for this model using the code below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "0df1c79e-27a7-4b0f-ba4e-167fe107125a",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "0df1c79e-27a7-4b0f-ba4e-167fe107125a",
    "outputId": "11ced939-556d-4511-d5c0-10a94ed3df32"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "float32 (PyTorch default): 52.33 GB\n",
      "bfloat16: 26.17 GB\n"
     ]
    }
   ],
   "source": [
    "def model_memory_size(model, input_dtype=torch.float32):\n",
    "    total_params = 0\n",
    "    total_grads = 0\n",
    "    for param in model.parameters():\n",
    "        # Calculate total number of elements per parameter\n",
    "        param_size = param.numel()\n",
    "        total_params += param_size\n",
    "        # Check if gradients are stored for this parameter\n",
    "        if param.requires_grad:\n",
    "            total_grads += param_size\n",
    "\n",
    "    # Calculate buffer size (non-parameters that require memory)\n",
    "    total_buffers = sum(buf.numel() for buf in model.buffers())\n",
    "\n",
    "    # Size in bytes = (Number of elements) * (Size of each element in bytes)\n",
    "    # We assume parameters and gradients are stored in the same type as input dtype\n",
    "    element_size = torch.tensor(0, dtype=input_dtype).element_size()\n",
    "    total_memory_bytes = (total_params + total_grads + total_buffers) * element_size\n",
    "\n",
    "    # Convert bytes to gigabytes\n",
    "    total_memory_gb = total_memory_bytes / (1024**3)\n",
    "\n",
    "    return total_memory_gb\n",
    "\n",
    "print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
    "print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
   ]
  },
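  {
   "cell_type": "markdown",
   "id": "d4b2a1f3-weights-only-sketch",
   "metadata": {},
   "source": [
    "- The function above also counts gradients and buffers, which is the relevant figure for training; for inference alone, a rough lower bound is the number of parameters times the bytes per element\n",
    "- The sketch below is not part of the original notebook; it reuses the parameter count printed earlier, and the helper name `weights_only_gb` is made up for this illustration:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4b2a1f3-weights-only-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "def weights_only_gb(num_params, bytes_per_param):\n",
    "    # Parameters only; gradients, buffers, and activations are ignored\n",
    "    return num_params * bytes_per_param / (1024**3)\n",
    "\n",
    "\n",
    "n_params = 6_738_415_616  # total reported by the parameter count above\n",
    "\n",
    "print(f\"float32 weights only: {weights_only_gb(n_params, 4):.2f} GB\")\n",
    "print(f\"bfloat16 weights only: {weights_only_gb(n_params, 2):.2f} GB\")"
   ]
  },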
  {
   "cell_type": "markdown",
   "id": "zudd-5PulKFL",
   "metadata": {
    "id": "zudd-5PulKFL"
   },
   "source": [
    "- Lastly, we can also transfer the model to an NVIDIA or Apple Silicon GPU if applicable:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "a4c50e19-1402-45b6-8ccd-9077b2ba836d",
   "metadata": {
    "id": "a4c50e19-1402-45b6-8ccd-9077b2ba836d"
   },
   "outputs": [],
   "source": [
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "elif torch.backends.mps.is_available():\n",
    "    device = torch.device(\"mps\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "\n",
    "model.to(device);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5dc64a06-27dc-46ec-9e6d-1700a8227d34",
   "metadata": {
    "id": "5dc64a06-27dc-46ec-9e6d-1700a8227d34"
   },
   "source": [
    "&nbsp;\n",
    "## 3. Load tokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0eb30f0c-6144-4bed-87d9-6b2bac377005",
   "metadata": {
    "id": "0eb30f0c-6144-4bed-87d9-6b2bac377005"
   },
   "source": [
    "- In this section, we are going to load the tokenizer for the model\n",
    "- Llama 2 uses Google's [SentencePiece](https://github.com/google/sentencepiece) tokenizer instead of OpenAI's [Tiktoken](https://github.com/openai/tiktoken) (but Llama 3 uses Tiktoken)\n",
    "- Meta AI shared the original Llama 2 model weights and tokenizer vocabulary on the Hugging Face Hub\n",
    "- We will download the tokenizer vocabulary from the Hub and load it into SentencePiece\n",
    "- Uncomment and run the following code to install the required libraries:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "768989ea-dc60-4dc8-ae84-cbb3fd224422",
   "metadata": {
    "id": "768989ea-dc60-4dc8-ae84-cbb3fd224422"
   },
   "outputs": [],
   "source": [
    "# !pip install huggingface_hub sentencepiece"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "KbnlzsbYmJU6",
   "metadata": {
    "id": "KbnlzsbYmJU6"
   },
   "source": [
    "- Please note that Meta AI requires that you accept the Llama 2 licensing terms before you can download the files; to do this, you have to create a Hugging Face Hub account and visit the [meta-llama/Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b) repository to accept the terms\n",
    "- Next, you will need to create an access token; to generate an access token with READ permissions, click on the profile picture in the upper right and click on \"Settings\"\n",
    "\n",
    "\n",
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/settings.webp?1\" width=\"300px\">\n",
    "\n",
    "- Then, create and copy the access token; the next code cell reads it from the `HF_ACCESS_TOKEN` entry of a local `config.json` file, so save it there first\n",
    "\n",
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/access-token.webp?1\" width=\"600px\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "3357a230-b678-4691-a238-257ee4e80185",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "3357a230-b678-4691-a238-257ee4e80185",
    "outputId": "768ed6af-ce14-40bc-ca18-117b4b448269"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n",
      "Token is valid (permission: read).\n",
      "Your token has been saved to /root/.cache/huggingface/token\n",
      "Login successful\n"
     ]
    }
   ],
   "source": [
    "from huggingface_hub import login\n",
    "import json\n",
    "\n",
    "with open(\"config.json\", \"r\") as config_file:\n",
    "    config = json.load(config_file)\n",
    "    access_token = config[\"HF_ACCESS_TOKEN\"]\n",
    "\n",
    "login(token=access_token)"
   ]
  },
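  {
   "cell_type": "markdown",
   "id": "e5c3b2a4-token-fallback-sketch",
   "metadata": {},
   "source": [
    "- If keeping the token in a `config.json` file is inconvenient, one possible variation is to fall back to an environment variable; the sketch below is not part of the original notebook, and both the helper name `read_hf_token` and the choice of `HF_TOKEN` as the variable name are assumptions made for this illustration\n",
    "- The returned value could then be passed to `login(token=...)` as in the cell above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5c3b2a4-token-fallback-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def read_hf_token(config_path=\"config.json\", env_var=\"HF_TOKEN\"):\n",
    "    # Hypothetical helper (not part of huggingface_hub): prefer the JSON file,\n",
    "    # then fall back to an environment variable\n",
    "    path = Path(config_path)\n",
    "    if path.exists():\n",
    "        with path.open(\"r\") as config_file:\n",
    "            return json.load(config_file)[\"HF_ACCESS_TOKEN\"]\n",
    "    return os.environ.get(env_var)  # None if the variable is not set\n",
    "\n",
    "\n",
    "access_token = read_hf_token()\n",
    "print(\"Token found:\", access_token is not None)"
   ]
  },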
  {
   "cell_type": "markdown",
   "id": "IxGh6ZYQo0VN",
   "metadata": {
    "id": "IxGh6ZYQo0VN"
   },
   "source": [
    "- After logging in with the access token, which is necessary to verify that we accepted the Llama 2 licensing terms, we can download the tokenizer vocabulary:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "69714ea8-b9b8-4687-8392-f3abb8f93a32",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 153,
     "referenced_widgets": [
      "e6c75a6aa7b942fe84160e286e3acb3d",
      "08f0bf9459bd425498a5cb236f9d4a72",
      "10251d6f724e43788c41d4b7879cbfd3",
      "53a973c0853b44418698136bd04df039",
      "bdb071e7145a4007ae01599333e72612",
      "6b1821a7f4574e3aba09c1e410cc81e4",
      "8c2873eaec3445888ad3d54ad7387950",
      "0c8f7044966e4207b12352503c67dcbb",
      "8b5951213c9e4798a258146d61d02d11",
      "2c05df3f91e64df7b33905b1065a76f7",
      "742ae5487f2648fcae7ca8e22c7f8db9"
     ]
    },
    "id": "69714ea8-b9b8-4687-8392-f3abb8f93a32",
    "outputId": "c230fec9-5c71-4a41-90ab-8a34d114ea01"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n",
      "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
      "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
      "You will be able to reuse this secret in all of your notebooks.\n",
      "Please note that authentication is recommended but still optional to access public models or datasets.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e6c75a6aa7b942fe84160e286e3acb3d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from huggingface_hub import hf_hub_download\n",
    "\n",
    "tokenizer_file = hf_hub_download(\n",
    "    repo_id=\"meta-llama/Llama-2-7b\",\n",
    "    filename=\"tokenizer.model\",\n",
    "    local_dir=\"Llama-2-7b\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "gp7iQ8cXAJLv",
   "metadata": {
    "id": "gp7iQ8cXAJLv"
   },
   "source": [
    "- To provide a more familiar interface for the tokenizer, we define a small `LlamaTokenizer` wrapper class:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "Ef4WxhjOBOOc",
   "metadata": {
    "id": "Ef4WxhjOBOOc"
   },
   "outputs": [],
   "source": [
    "import sentencepiece as spm\n",
    "\n",
    "\n",
    "class LlamaTokenizer:\n",
    "    def __init__(self, tokenizer_file):\n",
    "        sp = spm.SentencePieceProcessor()\n",
    "        sp.load(tokenizer_file)\n",
    "        self.tokenizer = sp\n",
    "\n",
    "    def encode(self, text):\n",
    "        return self.tokenizer.encode_as_ids(text)\n",
    "\n",
    "    def decode(self, ids):\n",
    "        # `ids` are integer token IDs, so decode them as IDs rather than as string pieces\n",
    "        return self.tokenizer.decode_ids(ids)\n",
    "\n",
    "\n",
    "tokenizer = LlamaTokenizer(tokenizer_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "NVhmFeX3pT_M",
   "metadata": {
    "id": "NVhmFeX3pT_M"
   },
   "source": [
    "- We can now use the `generate` function to have the Llama 2 model generate new text:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "e0a2b5cd-6cba-4d72-b8ff-04d8315d483e",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "e0a2b5cd-6cba-4d72-b8ff-04d8315d483e",
    "outputId": "acd5065d-8900-4ba8-ef85-968365f3a0cb"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output text:\n",
      " Every effort movesαllRadius deletingpretcc否']; future eer napulate lackус während inter DES издаSchéon로жа Bass differencespadxsnu ;; ctx始\n"
     ]
    }
   ],
   "source": [
    "from previous_chapters import generate, text_to_token_ids, token_ids_to_text\n",
    "# If the `previous_chapters.py` file is not available locally,\n",
    "# you can import it from the `llms-from-scratch` PyPI package.\n",
    "# For details, see: https://github.com/rasbt/LLMs-from-scratch/tree/main/pkg\n",
    "# E.g.,\n",
    "# from llms_from_scratch.ch05 import generate, text_to_token_ids, token_ids_to_text\n",
    "\n",
    "\n",
    "\n",
    "torch.manual_seed(123)\n",
    "\n",
    "token_ids = generate(\n",
    "    model=model,\n",
    "    idx=text_to_token_ids(\"Every effort moves\", tokenizer).to(device),\n",
    "    max_new_tokens=30,\n",
    "    context_size=LLAMA2_CONFIG_7B[\"context_length\"],\n",
    "    top_k=1,\n",
    "    temperature=0.\n",
    ")\n",
    "\n",
    "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93WTtAA5paYV",
   "metadata": {
    "id": "93WTtAA5paYV"
   },
   "source": [
    "- Of course, as we can see above, the text is nonsensical since we haven't trained the Llama 2 model yet\n",
    "- In the next section, instead of training it ourselves, which would cost tens to hundreds of thousands of dollars, we load the pretrained weights from Meta AI"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f63cc248-1d27-4eb6-aa50-173b436652f8",
   "metadata": {
    "id": "f63cc248-1d27-4eb6-aa50-173b436652f8"
   },
   "source": [
    "&nbsp;\n",
    "## 4. Load pretrained weights"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aKeN7rUfqZMI",
   "metadata": {
    "id": "aKeN7rUfqZMI"
   },
   "source": [
    "- We are loading the [\"meta-llama/Llama-2-7b\"](https://huggingface.co/meta-llama/Llama-2-7b) base model below, which is a simple text completion model before finetuning\n",
    "- Alternatively, you can load the instruction-finetuned and aligned [\"meta-llama/Llama-2-7b-chat\"](https://huggingface.co/meta-llama/Llama-2-7b-chat) model by modifying the string in the next code cell accordingly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "5fa9c06c-7a53-4b4d-9ce4-acc027322ee4",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 49,
     "referenced_widgets": [
      "66e777955e8748df878f118f07f38dab",
      "da89ae3ea4d2474e98f64ada608f3cea",
      "93e6da39c25f4edfaa72056c89df1f7f",
      "b628603e4cb0405398c916587ee96756",
      "93bedcb9245e44a0a1eb7e4155070f66",
      "0723f467d37b4904819a8bb33ebda10f",
      "e54928776bc649339002adced63738b0",
      "d8e0f42068af4cb094e2f115f76e06e0",
      "0a939565b6e94f08bee0a66e0f9827d4",
      "a5fedbb7ec2e43d99711bb4cd84b9486",
      "0c186f6539714d8eab023969ce47c500"
     ]
    },
    "id": "5fa9c06c-7a53-4b4d-9ce4-acc027322ee4",
    "outputId": "0d8942cc-e5e2-4e77-ec41-1ac7bec7d94f"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "66e777955e8748df878f118f07f38dab",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "consolidated.00.pth:   0%|          | 0.00/13.5G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "weights_file = hf_hub_download(\n",
    "   repo_id=\"meta-llama/Llama-2-7b\",\n",
    "   filename=\"consolidated.00.pth\",\n",
    "   local_dir=\"Llama-2-7b\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "e67cca5c-ba4b-4be5-85c7-fdceae8a5701",
   "metadata": {
    "id": "e67cca5c-ba4b-4be5-85c7-fdceae8a5701"
   },
   "outputs": [],
   "source": [
    "weights = torch.load(weights_file, weights_only=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "-15SJ7btq2zE",
   "metadata": {
    "id": "-15SJ7btq2zE"
   },
   "source": [
    "- The `weights` dictionary contains the following tensors (only the first 15 are shown for simplicity):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "ee26bd0b-fea9-4924-97f7-409c14f28e49",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ee26bd0b-fea9-4924-97f7-409c14f28e49",
    "outputId": "fa83d38a-bb41-4cb2-d3c7-c573bfe1f8a4"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['tok_embeddings.weight',\n",
       " 'norm.weight',\n",
       " 'output.weight',\n",
       " 'layers.0.attention.wq.weight',\n",
       " 'layers.0.attention.wk.weight',\n",
       " 'layers.0.attention.wv.weight',\n",
       " 'layers.0.attention.wo.weight',\n",
       " 'layers.0.feed_forward.w1.weight',\n",
       " 'layers.0.feed_forward.w2.weight',\n",
       " 'layers.0.feed_forward.w3.weight',\n",
       " 'layers.0.attention_norm.weight',\n",
       " 'layers.0.ffn_norm.weight',\n",
       " 'layers.1.attention.wq.weight',\n",
       " 'layers.1.attention.wk.weight',\n",
       " 'layers.1.attention.wv.weight']"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "list(weights.keys())[:15]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "UeeSpnunrDFB",
   "metadata": {
    "id": "UeeSpnunrDFB"
   },
   "source": [
    "- The following function, modeled after the `load_weights_into_gpt` function in [chapter 5](../01_main-chapter-code/ch05.ipynb), loads the pretrained weights into our Llama 2 model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "3820e2a7-4f26-41bc-953b-f3879b0aff65",
   "metadata": {
    "id": "3820e2a7-4f26-41bc-953b-f3879b0aff65"
   },
   "outputs": [],
   "source": [
    "def assign(left, right):\n",
    "    if left.shape != right.shape:\n",
    "        raise ValueError(f\"Shape mismatch. Left: {left.shape}, Right: {right.shape}\")\n",
    "\n",
    "    if isinstance(right, torch.Tensor):\n",
    "        return torch.nn.Parameter(right.clone().detach())\n",
    "    else:\n",
    "        return torch.nn.Parameter(torch.tensor(right))\n",
    "\n",
    "\n",
    "def load_weights_into_llama(model, param_config, params):\n",
    "    model.tok_emb.weight = assign(model.tok_emb.weight, params[\"tok_embeddings.weight\"])\n",
    "\n",
    "    for l in range(param_config[\"n_layers\"]):\n",
    "\n",
    "        # Load attention weights\n",
    "        model.trf_blocks[l].att.W_query.weight = assign(\n",
    "            model.trf_blocks[l].att.W_query.weight,\n",
    "            params[f\"layers.{l}.attention.wq.weight\"]\n",
    "        )\n",
    "        model.trf_blocks[l].att.W_key.weight = assign(\n",
    "            model.trf_blocks[l].att.W_key.weight,\n",
    "            params[f\"layers.{l}.attention.wk.weight\"]\n",
    "        )\n",
    "        model.trf_blocks[l].att.W_value.weight = assign(\n",
    "            model.trf_blocks[l].att.W_value.weight,\n",
    "            params[f\"layers.{l}.attention.wv.weight\"]\n",
    "        )\n",
    "        model.trf_blocks[l].att.out_proj.weight = assign(\n",
    "            model.trf_blocks[l].att.out_proj.weight,\n",
    "            params[f\"layers.{l}.attention.wo.weight\"]\n",
    "        )\n",
    "        model.trf_blocks[l].norm1.weight = assign(\n",
    "            model.trf_blocks[l].norm1.weight,\n",
    "            params[f\"layers.{l}.attention_norm.weight\"]\n",
    "        )\n",
    "\n",
    "        # Load FeedForward weights\n",
    "        model.trf_blocks[l].ff.fc1.weight = assign(\n",
    "            model.trf_blocks[l].ff.fc1.weight,\n",
    "            params[f\"layers.{l}.feed_forward.w1.weight\"]\n",
    "        )\n",
    "        # The checkpoint numbers these layers differently: w3 has the shape of our fc2\n",
    "        # and w2 has the shape of our fc3, so the two are swapped when loading\n",
    "        model.trf_blocks[l].ff.fc2.weight = assign(\n",
    "            model.trf_blocks[l].ff.fc2.weight,\n",
    "            params[f\"layers.{l}.feed_forward.w3.weight\"]\n",
    "        )\n",
    "        model.trf_blocks[l].ff.fc3.weight = assign(\n",
    "            model.trf_blocks[l].ff.fc3.weight,\n",
    "            params[f\"layers.{l}.feed_forward.w2.weight\"]\n",
    "        )\n",
    "        model.trf_blocks[l].norm2.weight = assign(\n",
    "            model.trf_blocks[l].norm2.weight,\n",
    "            params[f\"layers.{l}.ffn_norm.weight\"]\n",
    "        )\n",
    "\n",
    "    # Load output layer weights\n",
    "    model.final_norm.weight = assign(model.final_norm.weight, params[\"norm.weight\"])\n",
    "    model.out_head.weight = assign(model.out_head.weight, params[\"output.weight\"])\n",
    "\n",
    "\n",
    "load_weights_into_llama(model, LLAMA2_CONFIG_7B, weights)\n",
    "model.to(device);"
   ]
  },
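  {
   "cell_type": "markdown",
   "id": "f6d4c3b5-key-map-sketch",
   "metadata": {},
   "source": [
    "- The name mapping used by `load_weights_into_llama` can also be written out as a plain dictionary, which makes it easy to check which checkpoint keys the loader expects before any tensors are copied\n",
    "- The sketch below is not part of the original notebook; the helper name `expected_checkpoint_keys` is made up for this illustration, and the mapping simply restates the assignments in the function above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6d4c3b5-key-map-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "def expected_checkpoint_keys(n_layers):\n",
    "    # Checkpoint key -> parameter name in this notebook's model\n",
    "    mapping = {\n",
    "        \"tok_embeddings.weight\": \"tok_emb.weight\",\n",
    "        \"norm.weight\": \"final_norm.weight\",\n",
    "        \"output.weight\": \"out_head.weight\",\n",
    "    }\n",
    "    per_layer = {\n",
    "        \"attention.wq\": \"att.W_query\",\n",
    "        \"attention.wk\": \"att.W_key\",\n",
    "        \"attention.wv\": \"att.W_value\",\n",
    "        \"attention.wo\": \"att.out_proj\",\n",
    "        \"attention_norm\": \"norm1\",\n",
    "        \"feed_forward.w1\": \"ff.fc1\",\n",
    "        \"feed_forward.w3\": \"ff.fc2\",  # w3 and w2 are swapped relative to fc2 and fc3\n",
    "        \"feed_forward.w2\": \"ff.fc3\",\n",
    "        \"ffn_norm\": \"norm2\",\n",
    "    }\n",
    "    for layer in range(n_layers):\n",
    "        for src, dst in per_layer.items():\n",
    "            mapping[f\"layers.{layer}.{src}.weight\"] = f\"trf_blocks.{layer}.{dst}.weight\"\n",
    "    return mapping\n",
    "\n",
    "\n",
    "key_map = expected_checkpoint_keys(n_layers=32)\n",
    "print(\"Number of tensors the loader reads:\", len(key_map))\n",
    "\n",
    "# With the checkpoint loaded as `weights`, missing entries could be listed via:\n",
    "# missing = [key for key in key_map if key not in weights]"
   ]
  },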
  {
   "cell_type": "markdown",
   "id": "TDuv_Us2rNvk",
   "metadata": {
    "id": "TDuv_Us2rNvk"
   },
   "source": [
    "- Next, we are ready to use the model for text generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "240987e8-a023-462e-9376-9edfb27559ec",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "240987e8-a023-462e-9376-9edfb27559ec",
    "outputId": "044f24b3-4018-4860-834d-6c2731b9e47c"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output text:\n",
      " Every effort has been made to ensure that the information contained in this website is accurate and up to date and correct at the time of publication\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(123)\n",
    "\n",
    "token_ids = generate(\n",
    "    model=model,\n",
    "    idx=text_to_token_ids(\"Every effort\", tokenizer).to(device),\n",
    "    max_new_tokens=25,\n",
    "    context_size=LLAMA2_CONFIG_7B[\"context_length\"],\n",
    "    top_k=1,\n",
    "    temperature=0.\n",
    ")\n",
    "\n",
    "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d72ed949-b6c0-4966-922f-eb0da732c404",
   "metadata": {},
   "source": [
    "&nbsp;\n",
    "## 5. Using the instruction-finetuned model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "akyo7WNyF_YL",
   "metadata": {
    "id": "akyo7WNyF_YL"
   },
   "source": [
    "- The previous section used the pretrained base model; to use a model capable of following instructions, load the `\"meta-llama/Llama-2-7b-chat\"` model instead, as shown below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "nbvAV7vaz6yc",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 101,
     "referenced_widgets": [
      "3b2448a60f5f4ba5b2c686037c8ecd78",
      "60c5932944f24f5fad1d8da89c8e5ae9",
      "aa31aed1b8854a4281fd7e81c60e1205",
      "d4acf06c2414412f8f2fb4f48981c954",
      "693d69251d3d48219c084af17b54b851",
      "ff36d28c55dd4db3a0f76a87640fdfe2",
      "71c49ef820494d5f8908a3daf39f0755",
      "525dc406534f4369b11208816f8fd0d7",
      "865f39213a7341b68f2fe73caaf801b1",
      "eaf4c0231b6d4993b2f8e9e63d8b6921",
      "a11edf3b018e42c88a63a515cf7fe478"
     ]
    },
    "id": "nbvAV7vaz6yc",
    "outputId": "724f5508-d976-4e31-b3d7-95fa65b2c1e8"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3b2448a60f5f4ba5b2c686037c8ecd78",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "consolidated.00.pth:   0%|          | 0.00/13.5G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output text:\n",
      " What do llamas eat?\n",
      "Llamas and alpacas are herbivores, which means they eat grasses, leaves, grass\n"
     ]
    }
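   ],
   "source": [
    "del model  # to free up memory\n",
    "\n",
    "weights_file = hf_hub_download(\n",
    "   repo_id=\"meta-llama/Llama-2-7b-chat\",\n",
    "   filename=\"consolidated.00.pth\",\n",
    "   local_dir=\"Llama-2-7b-chat\"\n",
    ")\n",
    "\n",
    "# Load the chat checkpoint that was just downloaded; otherwise `weights` still holds the base model\n",
    "weights = torch.load(weights_file, weights_only=True)\n",
    "\n",
    "model = Llama2Model(LLAMA2_CONFIG_7B)\n",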
    "load_weights_into_llama(model, LLAMA2_CONFIG_7B, weights)\n",
    "model.to(device);\n",
    "\n",
    "torch.manual_seed(123)\n",
    "\n",
    "token_ids = generate(\n",
    "    model=model,\n",
    "    idx=text_to_token_ids(\"What do llamas eat?\", tokenizer).to(device),\n",
    "    max_new_tokens=25,\n",
    "    context_size=LLAMA2_CONFIG_7B[\"context_length\"],\n",
    "    top_k=1,\n",
    "    temperature=0.\n",
    ")\n",
    "\n",
    "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f693da1-a07c-4e1d-af5a-c3923525f1e2",
   "metadata": {},
   "source": [
    "&nbsp;\n",
    "# What's next?"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fae93739-ca12-46ba-8ca7-7c07c59f669b",
   "metadata": {},
   "source": [
    "- This notebook converted the original GPT-2 architecture into a Llama 2 model\n",
    "- If you are interested in how to convert Llama 2 into Llama 3, Llama 3.1, and Llama 3.2, check out the [converting-llama2-to-llama3.ipynb](converting-llama2-to-llama3.ipynb) notebook"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "A100",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
